# Copyright 2018, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A compiler for the test backend."""

import collections
from collections.abc import Callable

import numpy as np
import tensorflow as tf

from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.backends.mapreduce import intrinsics as mapreduce_intrinsics
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_building_block_factory
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_computation_factory
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_tree_transformations
from tensorflow_federated.python.core.impl.compiler import building_block_factory
from tensorflow_federated.python.core.impl.compiler import building_blocks
from tensorflow_federated.python.core.impl.compiler import intrinsic_defs
from tensorflow_federated.python.core.impl.computation import computation_impl
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_conversions


def _ensure_structure(
    int_or_structure, int_or_structure_type, possible_struct_type
):
  """Broadcasts a non-struct value to the shape of a struct type, if needed.

  Args:
    int_or_structure: The value to (possibly) broadcast.
    int_or_structure_type: The TFF type of `int_or_structure`.
    possible_struct_type: The TFF type whose structure the result should match
      when it is a `computation_types.StructType`.

  Returns:
    `int_or_structure` unchanged if its type is already a `StructType` or if
    `possible_struct_type` is not a `StructType`; otherwise a structure shaped
    like `possible_struct_type` with every leaf set to `int_or_structure`.
  """
  value_is_struct = isinstance(
      int_or_structure_type, computation_types.StructType
  )
  target_is_struct = isinstance(
      possible_struct_type, computation_types.StructType
  )
  if target_is_struct and not value_is_struct:
    # Replicate the single value into every leaf of the target structure.
    return structure.map_structure(
        lambda *_: int_or_structure, possible_struct_type
    )
  return int_or_structure


def _get_secure_intrinsic_reductions() -> dict[
    str,
    Callable[
        [building_blocks.ComputationBuildingBlock],
        building_blocks.ComputationBuildingBlock,
    ],
]:
  """Returns map from intrinsic to reducing function.

  The returned dictionary is a `collections.OrderedDict` which maps intrinsic
  URIs to functions from building-block intrinsic arguments to an implementation
  of the intrinsic call which has been reduced to a smaller, more fundamental
  set of intrinsics.

  WARNING: the reductions returned here will produce computation bodies that do
  **NOT** perform the crypto protocol. This method is intended only for testing
  settings.

  Bodies generated by later dictionary entries will not contain references
  to intrinsics whose entries appear earlier in the dictionary. This property
  is useful for simple reduction of an entire computation by iterating through
  the map of intrinsics, substituting calls to each.

  The entries, in order, are:

  * `FEDERATED_SECURE_SUM_BITWIDTH`: converts each bitwidth `b` (asserted to be
    at most 62) into a max input of `2**b - 1`, then applies the secure-sum
    reduction below.
  * `FEDERATED_SECURE_MODULAR_SUM`: a plain `federated_sum` of the summands,
    zipped with the modulus placed at the server, followed by an element-wise
    `tf.math.mod`.
  * `FEDERATED_SECURE_SUM`: a `federated_aggregate` whose accumulate step
    asserts each client value is <= the max input before adding it.
  * `FEDERATED_SECURE_SELECT`: a `federated_select` built with `secure=False`.

  Returns:
    An ordered mapping from intrinsic URI to a function that takes the
    intrinsic's argument building block and returns a replacement building
    block.
  """

  def federated_secure_sum(arg):
    # Reduces a secure sum to a `federated_aggregate`. `arg` is a struct whose
    # element 0 is the federated summand (its `.member` type is read below)
    # and whose element 1 is the max input value.
    py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
    summand_arg = building_blocks.Selection(arg, index=0)
    summand_type = summand_arg.type_signature.member  # pytype: disable=attribute-error
    max_input_arg = building_blocks.Selection(arg, index=1)
    max_input_type = max_input_arg.type_signature

    # Add the max_value as a second value in the zero, so it can be read during
    # `accumulate` to ensure client summands are valid. The value will be
    # later dropped in `report`.
    #
    # While accumulating summands, we'll assert each summand is less than or
    # equal to max_input. Otherwise the computation should issue an error.
    summation_zero = tensorflow_building_block_factory.create_generic_constant(
        summand_type, 0
    )
    aggregation_zero = building_blocks.Struct(
        [summation_zero, max_input_arg], container_type=tuple
    )

    # Accumulate step: checks the summand against the carried max input, then
    # adds it to the running sum. The max input is passed through unchanged.
    def assert_less_equal_max_and_add(summation_and_max_input, summand):
      summation, original_max_input = summation_and_max_input
      # A scalar max input is broadcast to the summand's structure so the
      # element-wise check below can pair each summand leaf with a bound.
      max_input = _ensure_structure(
          original_max_input, max_input_type, summand_type
      )

      # Assert that all coordinates in all tensors are less than or equal to
      # the secure sum allowed max input value.
      def assert_all_coordinates_less_equal(x, m):
        # Both sides are cast to int64 before comparing.
        return tf.Assert(
            tf.reduce_all(
                tf.less_equal(tf.cast(x, np.int64), tf.cast(m, np.int64))
            ),
            [
                'client value larger than maximum specified for secure sum',
                x,
                'not less than or equal to',
                m,
            ],
        )

      assert_ops = structure.flatten(
          structure.map_structure(
              assert_all_coordinates_less_equal, summand, max_input
          )
      )
      # The addition is created under the asserts' control dependencies so the
      # checks run before the summand is accumulated.
      with tf.control_dependencies(assert_ops):
        return (
            structure.map_structure(tf.add, summation, summand),
            original_max_input,
        )

    assert_less_equal_and_add_proto, assert_less_equal_and_add_type = (
        tensorflow_computation_factory.create_binary_operator(
            assert_less_equal_max_and_add,
            operand_type=aggregation_zero.type_signature,
            second_operand_type=summand_type,
        )
    )
    assert_less_equal_and_add = building_blocks.CompiledComputation(
        assert_less_equal_and_add_proto,
        type_signature=assert_less_equal_and_add_type,
    )

    # Merge step: element-wise addition of two (summation, max_input)
    # accumulators.
    # NOTE(review): this also adds the two max_input components together. The
    # result is harmless only because `report` (drop_max_value_op) discards
    # that element; do not rely on the merged max_input value.
    def nested_plus(a, b):
      return structure.map_structure(tf.add, a, b)

    plus_proto, plus_type = (
        tensorflow_computation_factory.create_binary_operator(
            nested_plus, operand_type=aggregation_zero.type_signature
        )
    )
    plus_op = building_blocks.CompiledComputation(
        plus_proto, type_signature=plus_type
    )

    # In the `report` function we take the summation and drop the second element
    # of the struct (which was holding the max_value).
    drop_max_value_proto, drop_max_value_type = (
        tensorflow_computation_factory.create_unary_operator(
            lambda x: type_conversions.type_to_py_container(x[0], summand_type),
            aggregation_zero.type_signature,
        )
    )
    drop_max_value_op = building_blocks.CompiledComputation(
        drop_max_value_proto, type_signature=drop_max_value_type
    )

    return building_block_factory.create_federated_aggregate(
        summand_arg,
        aggregation_zero,
        assert_less_equal_and_add,
        plus_op,
        drop_max_value_op,
    )

  def federated_secure_sum_bitwidth(arg):
    # Reduces a bitwidth-based secure sum by converting the bitwidth into an
    # inclusive max input (`2**bits - 1`) and reusing the secure-sum reduction.
    # `arg` element 0 is the summand and element 1 is the bitwidth.
    py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
    summand_arg = building_blocks.Selection(arg, index=0)
    bitwidth_arg = building_blocks.Selection(arg, index=1)

    # Compute the max_input value from the provided bitwidth.
    def max_input_from_bitwidth(bitwidth):
      # Secure sum is performed with int64, which has 63 bits, and we need at
      # least one bit to hold the summation of two client values.
      max_secure_sum_bitwidth = 62

      def compute_max_input(bits):
        assert_op = tf.Assert(
            tf.less_equal(bits, max_secure_sum_bitwidth),
            [
                bits,
                f'is greater than maximum bitwidth {max_secure_sum_bitwidth}',
            ],
        )
        # The bitwidth check must run before the max value is produced.
        with tf.control_dependencies([assert_op]):
          return (
              tf.math.pow(tf.constant(2, tf.int64), tf.cast(bits, tf.int64)) - 1
          )

      # Applied per leaf, so a structured bitwidth yields a structured max.
      return structure.map_structure(compute_max_input, bitwidth)

    proto, type_signature = (
        tensorflow_computation_factory.create_unary_operator(
            max_input_from_bitwidth, bitwidth_arg.type_signature
        )
    )
    compute_max_value_op = building_blocks.CompiledComputation(
        proto, type_signature=type_signature
    )

    max_value = building_blocks.Call(compute_max_value_op, bitwidth_arg)
    # Calls the Python reduction directly (not the secure-sum intrinsic), so
    # the produced body contains no secure intrinsic.
    return federated_secure_sum(
        building_blocks.Struct([summand_arg, max_value])
    )

  def federated_secure_modular_sum(arg):
    # Reduces a modular secure sum to a plain `federated_sum` followed by an
    # element-wise `tf.math.mod` with the modulus at the server.
    # `arg` element 0 is the summand and element 1 is the (unplaced) modulus.
    py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
    if not isinstance(arg.type_signature, computation_types.StructType):
      raise ValueError(
          f'Expected a `tff.StructType`, found {arg.type_signature}.'
      )
    # Reuse the argument's Python container (if any) for the zipped
    # (sum, modulus) pair built below.
    if isinstance(arg.type_signature, computation_types.StructWithPythonType):
      container_type = arg.type_signature.python_container
    else:
      container_type = None
    summand_arg = building_blocks.Selection(arg, index=0)
    raw_summed_values = building_block_factory.create_federated_sum(summand_arg)

    unplaced_modulus = building_blocks.Selection(arg, index=1)
    placed_modulus = building_block_factory.create_federated_value(
        unplaced_modulus, placements.SERVER
    )
    modulus_arg = building_block_factory.create_federated_zip(
        building_blocks.Struct(
            [raw_summed_values, placed_modulus], container_type=container_type
        )
    )

    # Broadcasts a scalar modulus to the summed values' structure if needed,
    # then applies `tf.math.mod` leaf by leaf.
    def map_structure_mod(summed_values, modulus):
      modulus = _ensure_structure(
          modulus,
          unplaced_modulus.type_signature,
          raw_summed_values.type_signature.member,  # pytype: disable=attribute-error
      )
      return structure.map_structure(tf.math.mod, summed_values, modulus)

    proto, type_signature = (
        tensorflow_computation_factory.create_binary_operator(
            map_structure_mod,
            operand_type=raw_summed_values.type_signature.member,  # pytype: disable=attribute-error
            second_operand_type=placed_modulus.type_signature.member,  # pytype: disable=attribute-error
        )
    )
    modulus_fn = building_blocks.CompiledComputation(
        proto, type_signature=type_signature
    )
    modulus_computed = building_block_factory.create_federated_apply(
        modulus_fn, modulus_arg
    )

    return modulus_computed

  def federated_secure_select(arg):
    # Forwards the four arguments (client keys, max key, server value, select
    # fn) to a non-secure `federated_select`.
    py_typecheck.check_type(arg, building_blocks.ComputationBuildingBlock)
    client_keys_arg = building_blocks.Selection(arg, index=0)
    max_key_arg = building_blocks.Selection(arg, index=1)
    server_val_arg = building_blocks.Selection(arg, index=2)
    select_fn_arg = building_blocks.Selection(arg, index=3)
    return building_block_factory.create_federated_select(
        client_keys_arg,
        max_key_arg,
        server_val_arg,
        select_fn_arg,
        secure=False,
    )

  # Order matters (see the docstring): entries are reduced in this order by
  # `_replace_secure_intrinsics_with_insecure_bodies`.
  secure_intrinsic_bodies_by_uri = collections.OrderedDict([
      (
          intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH.uri,
          federated_secure_sum_bitwidth,
      ),
      (
          mapreduce_intrinsics.FEDERATED_SECURE_MODULAR_SUM.uri,
          federated_secure_modular_sum,
      ),
      (intrinsic_defs.FEDERATED_SECURE_SUM.uri, federated_secure_sum),
      (intrinsic_defs.FEDERATED_SECURE_SELECT.uri, federated_secure_select),
  ])
  return secure_intrinsic_bodies_by_uri


def _replace_secure_intrinsics_with_insecure_bodies(comp):
  """Inlines insecure bodies for every secure intrinsic found in `comp`.

  This operates at the AST level: it takes a
  `building_blocks.ComputationBuildingBlock` and returns one. The reductions
  come from `_get_secure_intrinsic_reductions()` and are applied one intrinsic
  URI at a time, in that dictionary's order. Correctness relies on that
  function's contract that bodies produced by later entries do not reference
  intrinsics whose entries appear earlier.

  Args:
    comp: Instance of `building_blocks.ComputationBuildingBlock` in which to
      replace all secure intrinsics with their insecure bodies.

  Returns:
    A tuple `(comp, transformed)`: the transformed building block, and a flag
    indicating whether any intrinsic reduction was in fact performed.

  Raises:
    TypeError: If `comp` is not a `building_blocks.ComputationBuildingBlock`.
  """
  py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
  any_reduced = False
  for intrinsic_uri, reduction in _get_secure_intrinsic_reductions().items():
    comp, reduced_this_uri = tensorflow_tree_transformations.reduce_intrinsic(
        comp, intrinsic_uri, reduction
    )
    any_reduced = any_reduced or reduced_this_uri
  return comp, any_reduced


def replace_secure_intrinsics_with_bodies(comp):
  """Replace `secure_...` intrinsics with insecure TensorFlow equivalents.

  Intended for tests only. Usages of `tff.federated_secure_sum`,
  `tff.federated_secure_sum_bitwidth`, the mapreduce backend's
  `federated_secure_modular_sum`, and `tff.federated_secure_select` are
  replaced with equivalent non-secure implementations. The resulting
  computation can then run on TFF runtimes that do not implement secure
  computation.

  Args:
    comp: The computation to transform; it must provide `to_building_block()`.

  Returns:
    A `computation_impl.ConcreteComputation` equivalent to `comp`, with secure
    intrinsics replaced by insecure TensorFlow equivalents.
  """
  building_block = comp.to_building_block()
  # The "was anything reduced" flag is not needed by callers of this function.
  reduced_block, _unused_was_reduced = (
      _replace_secure_intrinsics_with_insecure_bodies(building_block)
  )
  return computation_impl.ConcreteComputation.from_building_block(
      reduced_block
  )
